{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "cat-dog-classification.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qGKs9H6LLXc5"
      },
      "source": [
        "This notebook trains a cat-vs-dog classifier, using a pretrained ResNet-18 as the feature extractor.\n",
        "\n",
        "The training code is borrowed from the [PyTorch transfer learning tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html), and the dataset is taken from Kaggle. You can read more about using Kaggle datasets on Google Colab [here](https://www.kaggle.com/general/74235)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6nLuTphYU8BX"
      },
      "source": [
        "!pip install -q kaggle"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z7R7-6FSU-uM"
      },
      "source": [
        "# Upload kaggle.json file with kaggle API token here\n",
        "from google.colab import files\n",
        "uploaded = files.upload()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "U9GNKjPaVC6e"
      },
      "source": [
        "!mkdir -p ~/.kaggle"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W87-8iNhVFt7"
      },
      "source": [
        "! cp kaggle.json ~/.kaggle/"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q2BECUfbVKjp"
      },
      "source": [
        "! chmod 600 ~/.kaggle/kaggle.json"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KPlzGXOfVLuI"
      },
      "source": [
        "!kaggle competitions download -c dogs-vs-cats"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WOd-Av5PVRiV"
      },
      "source": [
        "!unzip train.zip\n",
        "!unzip test1.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GDMqi2k5WPyH"
      },
      "source": [
        "import os\n",
        "train_files = os.listdir('./train')\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NDJozL1pX0VD"
      },
      "source": [
        "!mkdir dog-cat-dataset\n",
        "!mkdir dog-cat-dataset/dogs\n",
        "!mkdir dog-cat-dataset/cats\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "V5As1TyGHsU-"
      },
      "source": [
        "!pwd"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ovd4sOHRovLw"
      },
      "source": [
        "Copy each training image into a `dogs` or `cats` folder, using the filename prefix (`dog.` or `cat.`) as the label."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mLUtBRIrXYcL"
      },
      "source": [
        "import shutil\n",
        "base_dir='/content/train'\n",
        "for filename in train_files:\n",
        "\n",
        "  category = filename.split('.')[0]\n",
        "  if category=='dog':\n",
        "    shutil.copyfile(os.path.join(base_dir,filename), os.path.join('/content/dog-cat-dataset/dogs',filename))\n",
        "  elif category=='cat':\n",
        "    shutil.copyfile(os.path.join(base_dir,filename), os.path.join('/content/dog-cat-dataset/cats',filename))\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nfgh3X7pJoKQ"
      },
      "source": [
        "%cd /content/dog-cat-dataset/cats\n",
        "!ls | wc -l \n",
        "%cd -"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iMgFbThJJ2wR"
      },
      "source": [
        "!mkdir ./dog-cat-dataset/test\n",
        "!mkdir ./dog-cat-dataset/test/dog\n",
        "!mkdir ./dog-cat-dataset/test/cat\n",
        "!mkdir ./dog-cat-dataset/val\n",
        "!mkdir ./dog-cat-dataset/val/dog\n",
        "!mkdir ./dog-cat-dataset/val/cat\n",
        "!mkdir ./dog-cat-dataset/train\n",
        "!mkdir ./dog-cat-dataset/train/dog\n",
        "!mkdir ./dog-cat-dataset/train/cat"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oHMnNTOHKtrK"
      },
      "source": [
        "%cd ./dog-cat-dataset/dogs/"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k3Cjm409JpSv"
      },
      "source": [
        "!ls | wc -l"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A_0mw691MH53"
      },
      "source": [
        "Here, we randomly split the dog and cat images into train/val/test sets (8000/2250/2250 images per class)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9LzEal-3LTex"
      },
      "source": [
        "# Run from dog-cat-dataset/dogs (the working directory set by the %cd cell above).\n",
        "# Use ! for the shell pipelines and %cd for changing directory, so the change persists across lines.\n",
        "!ls | shuf -n 2250 | xargs -I {} mv {} ../val/dog\n",
        "!ls | shuf -n 2250 | xargs -I {} mv {} ../test/dog\n",
        "!ls | shuf -n 8000 | xargs -I {} mv {} ../train/dog\n",
        "%cd ../cats\n",
        "!ls | shuf -n 2250 | xargs -I {} mv {} ../val/cat\n",
        "!ls | shuf -n 2250 | xargs -I {} mv {} ../test/cat\n",
        "!ls | shuf -n 8000 | xargs -I {} mv {} ../train/cat\n"
      ],
      "execution_count": null,
      "outputs": []
    },
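    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The shell pipeline above produces a different split on every run. As a rough alternative (not the method used in this notebook), the same per-class split could be sketched in plain Python with a fixed seed. The function name `split_class_images` and its parameters are hypothetical; the default split sizes mirror the `shuf` commands above. The cell only defines the function, because the files have already been moved."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import random\n",
        "import shutil\n",
        "from pathlib import Path\n",
        "\n",
        "def split_class_images(src_dir, dest_root, class_name, n_val=2250, n_test=2250, seed=0):\n",
        "    # Hypothetical helper: move the files in src_dir into dest_root/<split>/<class_name>.\n",
        "    files = sorted(p for p in Path(src_dir).iterdir() if p.is_file())\n",
        "    random.Random(seed).shuffle(files)  # a fixed seed makes the split repeatable\n",
        "    splits = {\n",
        "        'val': files[:n_val],\n",
        "        'test': files[n_val:n_val + n_test],\n",
        "        'train': files[n_val + n_test:],  # everything that is left\n",
        "    }\n",
        "    for split, paths in splits.items():\n",
        "        target = Path(dest_root) / split / class_name\n",
        "        target.mkdir(parents=True, exist_ok=True)\n",
        "        for path in paths:\n",
        "            shutil.move(str(path), str(target / path.name))\n",
        "    # Return the number of files moved into each split\n",
        "    return {split: len(paths) for split, paths in splits.items()}\n",
        "\n",
        "# Example call, left commented out on purpose:\n",
        "# split_class_images('/content/dog-cat-dataset/dogs', '/content/dog-cat-dataset', 'dog')"
      ],
      "execution_count": null,
      "outputs": []
    },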
    {
      "cell_type": "code",
      "metadata": {
        "id": "WqnoatbJMZNE"
      },
      "source": [
        "%cd ../../"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hYpLE7fPORxf"
      },
      "source": [
        "import os\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch.utils.data\n",
        "import torchvision\n",
        "from PIL import Image\n",
        "from torch.optim import lr_scheduler\n",
        "from torchvision import datasets, models, transforms"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "21FLnitamRPJ"
      },
      "source": [
        "Define the image transforms, datasets, and data loaders for the train, validation, and test sets."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NzoPeV3SMgZb"
      },
      "source": [
        "data_transforms = {\n",
        "    'train': transforms.Compose([\n",
        "        transforms.RandomResizedCrop(224),\n",
        "        transforms.RandomHorizontalFlip(),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
        "    ]),\n",
        "    'val': transforms.Compose([\n",
        "        transforms.Resize(256),\n",
        "        transforms.CenterCrop(224),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
        "    ]),\n",
        "    'test': transforms.Compose([\n",
        "        transforms.Resize(256),\n",
        "        transforms.CenterCrop(224),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
        "    ]),\n",
        "}\n",
        "\n",
        "data_dir = '/content/dog-cat-dataset/'\n",
        "image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),\n",
        "                                          data_transforms[x])\n",
        "                  for x in ['train', 'val', 'test']}\n",
        "dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,\n",
        "                                             shuffle=True, num_workers=2)\n",
        "              for x in ['train', 'val','test']}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "txbJ7AR5mmsY"
      },
      "source": [
        "Load the pretrained ResNet-18 from PyTorch Hub, freeze every parameter whose name does not contain `bn`, and replace the final fully connected layer with a new head: a linear layer with 512 units, a ReLU, a dropout layer, and a linear output layer for the two classes (cat and dog)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jU1XUuA4OfgO"
      },
      "source": [
        "pretrained_model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)\n",
        "\n",
        "# Freeze all parameters except those whose name contains \"bn\" (batch-norm layers)\n",
        "for name, param in pretrained_model.named_parameters():\n",
        "    if \"bn\" not in name:\n",
        "        param.requires_grad = False\n",
        "\n",
        "# Number of features produced by the backbone (input size of the original fc layer)\n",
        "num_ftrs = pretrained_model.fc.in_features\n",
        "\n",
        "# New classification head; its parameters are trainable by default\n",
        "pretrained_model.fc = nn.Sequential(nn.Linear(num_ftrs, 512),\n",
        "                                    nn.ReLU(),\n",
        "                                    nn.Dropout(),\n",
        "                                    nn.Linear(512, 2))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LItRc9k0m5re"
      },
      "source": [
        "Setting the optimizer and learning rate scheduler."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uWPjrEO7Ophb"
      },
      "source": [
        "optimizer = optim.Adam(pretrained_model.parameters(), lr=0.001)\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "# Decay LR by a factor of 0.1 every 7 epochs\n",
        "exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)"
      ],
      "execution_count": null,
      "outputs": []
    },
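    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what this schedule does, the cell below is a small sketch that applies the closed-form `StepLR` rule, `lr = base_lr * gamma ** (epoch // step_size)`, to the values set above. It assumes `scheduler.step()` is called once per epoch; the variable names are illustrative only. With a short run of only a few epochs, the learning rate never leaves its initial value."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Same values as in the optimizer cell above\n",
        "base_lr, step_size, gamma = 0.001, 7, 0.1\n",
        "\n",
        "# The rate is multiplied by gamma once for every completed block of step_size epochs\n",
        "for epoch in (0, 6, 7, 13, 14):\n",
        "    lr = base_lr * gamma ** (epoch // step_size)\n",
        "    print('epoch {:2d}: lr = {:.6f}'.format(epoch, lr))"
      ],
      "execution_count": null,
      "outputs": []
    },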
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KTsPsCounHsG"
      },
      "source": [
        "Define the training loop."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ePty2StyPAlY"
      },
      "source": [
        "def train_model(model, criterion, optimizer, scheduler, num_epochs=25):\n",
        "    since = time.time()\n",
        "\n",
        "    best_model_wts = copy.deepcopy(model.state_dict())\n",
        "    best_acc = 0.0\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
        "        print('-' * 10)\n",
        "\n",
        "        # Each epoch has a training and validation phase\n",
        "        for phase in ['train', 'val']:\n",
        "            if phase == 'train':\n",
        "                model.train()  # Set model to training mode\n",
        "            else:\n",
        "                model.eval()   # Set model to evaluate mode\n",
        "\n",
        "            running_loss = 0.0\n",
        "            running_corrects = 0\n",
        "\n",
        "            # Iterate over data.\n",
        "            for inputs, labels in dataloaders[phase]:\n",
        "                inputs = inputs.to(device)\n",
        "                labels = labels.to(device)\n",
        "\n",
        "                # zero the parameter gradients\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "                # forward\n",
        "                # track history if only in train\n",
        "                with torch.set_grad_enabled(phase == 'train'):\n",
        "                    outputs = model(inputs)\n",
        "                    _, preds = torch.max(outputs, 1)\n",
        "                    loss = criterion(outputs, labels)\n",
        "\n",
        "                    # backward + optimize only if in training phase\n",
        "                    if phase == 'train':\n",
        "                        loss.backward()\n",
        "                        optimizer.step()\n",
        "\n",
        "                # statistics\n",
        "                running_loss += loss.item() * inputs.size(0)\n",
        "                running_corrects += torch.sum(preds == labels.data)\n",
        "            if phase == 'train':\n",
        "                scheduler.step()\n",
        "\n",
        "            epoch_loss = running_loss / dataset_sizes[phase]\n",
        "            epoch_acc = running_corrects.double() / dataset_sizes[phase]\n",
        "\n",
        "            print('{} Loss: {:.4f} Acc: {:.4f}'.format(\n",
        "                phase, epoch_loss, epoch_acc))\n",
        "\n",
        "            # deep copy the model\n",
        "            if phase == 'val' and epoch_acc > best_acc:\n",
        "                best_acc = epoch_acc\n",
        "                best_model_wts = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        print()\n",
        "\n",
        "    time_elapsed = time.time() - since\n",
        "    print('Training complete in {:.0f}m {:.0f}s'.format(\n",
        "        time_elapsed // 60, time_elapsed % 60))\n",
        "    print('Best val Acc: {:4f}'.format(best_acc))\n",
        "\n",
        "    # load best model weights\n",
        "    model.load_state_dict(best_model_wts)\n",
        "    return model"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "i3DnJMxhPFU6"
      },
      "source": [
        "import time\n",
        "import copy"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EZgw4fh5PLhT"
      },
      "source": [
        "if torch.cuda.is_available():\n",
        "    device = torch.device(\"cuda\") \n",
        "else:\n",
        "    device = torch.device(\"cpu\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "F2YzeoVbPH1_"
      },
      "source": [
        "pretrained_model.to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vMRuhvbGPT-2"
      },
      "source": [
        "dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n",
        "class_names = image_datasets['train'].classes"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XFeEKECjPXr3"
      },
      "source": [
        "print(class_names)\n",
        "print(dataset_sizes)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_csze4BinO4R"
      },
      "source": [
        "Run the training for 2 epochs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IDTD62aDPatj"
      },
      "source": [
        "pretrained_model = train_model(pretrained_model, criterion, optimizer, exp_lr_scheduler,\n",
        "                       num_epochs=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "REjBvqCKmH6y"
      },
      "source": [
        "Testing the model performance on the test set."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e2ETh1uwV8cQ"
      },
      "source": [
        "def test_model(model):\n",
        "    model.eval()  # disable dropout and use running batch-norm statistics\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    with torch.no_grad():\n",
        "        for images, labels in dataloaders['test']:\n",
        "            images, labels = images.to(device), labels.to(device)\n",
        "            outputs = model(images)\n",
        "            _, predicted = torch.max(outputs, 1)\n",
        "            total += labels.size(0)\n",
        "            correct += (predicted == labels).sum().item()\n",
        "    print('correct: {:d}  total: {:d}'.format(correct, total))\n",
        "    print('accuracy = {:f}'.format(correct / total))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P6E_doT9WR61"
      },
      "source": [
        "test_model(pretrained_model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4zkHC8b0WsU4"
      },
      "source": [
        "torch.save(pretrained_model.state_dict(), \"./cat_dog_classification.pth\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_8BhEM09nUzE"
      },
      "source": [
        "Reload the saved weights into a fresh model as a sanity check, then evaluate it on the test dataset again."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zhxjNiXIYBKV"
      },
      "source": [
        "pretrained_model = torch.hub.load('pytorch/vision', 'resnet18')\n",
        "# Rebuild the same classification head so the saved state dict matches\n",
        "pretrained_model.fc = nn.Sequential(nn.Linear(pretrained_model.fc.in_features, 512), nn.ReLU(), nn.Dropout(), nn.Linear(512, 2))\n",
        "pretrained_model.load_state_dict(torch.load('./cat_dog_classification.pth', map_location=device))\n",
        "pretrained_model.eval()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nRGkjPxQgRZJ"
      },
      "source": [
        "pretrained_model.to(device)\n",
        "test_model(pretrained_model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rACobIG9lv96"
      },
      "source": [
        "Running some prediction tests. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZG1r5PseOlSP"
      },
      "source": [
        "def prediction(model, filename):\n",
        "    labels = class_names\n",
        "    img = Image.open(filename).convert('RGB')  # ensure 3 channels\n",
        "    img = data_transforms['test'](img)\n",
        "    img = img.unsqueeze(0)  # add a batch dimension\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        logits = model(img.to(device))\n",
        "    print(logits)\n",
        "    predicted_index = logits.argmax(dim=1).item()\n",
        "    print(labels)\n",
        "    print(labels[predicted_index])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "47_Gdw09iIZj"
      },
      "source": [
        "prediction(pretrained_model, '/content/dog-cat-dataset/test/cat/cat.100.jpg')\n",
        "prediction(pretrained_model, '/content/dog-cat-dataset/test/dog/dog.10033.jpg')"
      ],
      "execution_count": null,
      "outputs": []
    },
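    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `prediction` function prints raw scores (logits). As a minimal sketch, a softmax over the class dimension turns such scores into probabilities that sum to 1. The logit values below are made up for illustration, and the class order is assumed to be the alphabetical one that `ImageFolder` assigns (`['cat', 'dog']`)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "# Made-up logits for a single image, shape (batch=1, classes=2)\n",
        "example_logits = torch.tensor([[2.0, -1.0]])\n",
        "\n",
        "# Softmax over dim=1 normalizes the two class scores into probabilities\n",
        "probabilities = F.softmax(example_logits, dim=1)\n",
        "print(probabilities)\n",
        "print(probabilities.sum(dim=1))"
      ],
      "execution_count": null,
      "outputs": []
    },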
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VZbxP1UUndmH"
      },
      "source": [
        "Downloading the trained model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "t7Xhu7jSidnp"
      },
      "source": [
        "from google.colab import files\n",
        "files.download('./cat_dog_classification.pth') "
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}